Skip to content

Commit

Permalink
fix handling of experimental module compile flag (#56945)
Browse files Browse the repository at this point in the history
Add a new `finish!` function which skips any inference/optimization and
just directly uses the (uninferred) source as the result, setting all
fields correctly assuming they might have come from a generated function
in the very unlikely case it set some of them, and making sure this is
now correctly synchronized with the cache lookup and insertion calls
once again.

This code feature was added without any tests in #37041, so I cannot
guarantee there aren't any mistakes still lurking here, either mine or
original.

Fixes #53431
  • Loading branch information
vtjnash authored Jan 6, 2025
1 parent 36472a7 commit a23a6de
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 45 deletions.
72 changes: 59 additions & 13 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,39 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
return nothing
end

function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstance, src::CodeInfo)
user_edges = src.edges
edges = user_edges isa SimpleVector ? user_edges : user_edges === nothing ? Core.svec() : Core.svec(user_edges...)
relocatability = 0x1
const_flag = false
di = src.debuginfo
rettype = Any
exctype = Any
rettype_const = nothing
const_flags = 0x0
ipo_effects = zero(UInt32)
min_world = src.min_world
max_world = src.max_world
if max_world >= get_world_counter()
max_world = typemax(UInt)
end
if max_world == typemax(UInt)
# if we can record all of the backedges in the global reverse-cache,
# we can now widen our applicability in the global cache too
store_backedges(ci, edges)
end
ccall(:jl_fill_codeinst, Cvoid, (Any, Any, Any, Any, Int32, UInt, UInt, UInt32, Any, Any, Any),
ci, rettype, exctype, nothing, const_flags, min_world, max_world, ipo_effects, nothing, di, edges)
ccall(:jl_update_codeinst, Cvoid, (Any, Any, Int32, UInt, UInt, UInt32, Any, UInt8, Any, Any),
ci, nothing, const_flag, min_world, max_world, ipo_effects, nothing, relocatability, di, edges)
code_cache(interp)[mi] = ci
if isdefined(interp, :codegen)
interp.codegen[ci] = src
end
engine_reject(interp, ci)
return nothing
end

function finish_nocycle(::AbstractInterpreter, frame::InferenceState)
finishinfer!(frame, frame.interp)
opt = frame.result.src
Expand Down Expand Up @@ -826,7 +859,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
end
end
end
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0 && !generating_output(#=incremental=#false)
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0
add_remark!(interp, caller, "[typeinf_edge] Inference is disabled for the target module")
return Future(MethodCallResult(interp, caller, method, Any, Any, Effects(), nothing, edgecycle, edgelimited))
end
Expand Down Expand Up @@ -1096,15 +1129,6 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance, source_mod
end
end
def = mi.def
if isa(def, Method)
if ccall(:jl_get_module_infer, Cint, (Any,), def.module) == 0 && !generating_output(#=incremental=#false)
src = retrieve_code_info(mi, get_inference_world(interp))
src isa CodeInfo || return nothing
return CodeInstance(mi, cache_owner(interp), Any, Any, nothing, src, Int32(0),
get_inference_world(interp), get_inference_world(interp),
UInt32(0), nothing, UInt8(0), src.debuginfo, src.edges)
end
end
ci = engine_reserve(interp, mi)
# check cache again if it is still new after reserving in the engine
let code = get(code_cache(interp), mi, nothing)
Expand All @@ -1117,11 +1141,22 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance, source_mod
end
end
end
if isa(def, Method) && ccall(:jl_get_module_infer, Cint, (Any,), def.module) == 0
src = retrieve_code_info(mi, get_inference_world(interp))
if src isa CodeInfo
finish!(interp, mi, ci, src)
else
engine_reject(interp, ci)
end
ccall(:jl_typeinf_timing_end, Cvoid, (UInt64,), start_time)
return ci
end
result = InferenceResult(mi, typeinf_lattice(interp))
result.ci = ci
frame = InferenceState(result, #=cache_mode=#:global, interp)
if frame === nothing
engine_reject(interp, ci)
ccall(:jl_typeinf_timing_end, Cvoid, (UInt64,), start_time)
return nothing
end
typeinf(interp, frame)
Expand Down Expand Up @@ -1263,18 +1298,29 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::
callee in inspected && continue
push!(inspected, callee)
# now make sure everything has source code, if desired
# TODO: typeinf_code could return something with different edges/ages (needing an update to callee), which we don't handle here
mi = get_ci_mi(callee)
def = mi.def
if use_const_api(callee)
src = codeinfo_for_const(interp, callee.def, code.rettype_const)
src = codeinfo_for_const(interp, mi, code.rettype_const)
elseif haskey(interp.codegen, callee)
src = interp.codegen[callee]
elseif isa(def, Method) && ccall(:jl_get_module_infer, Cint, (Any,), def.module) == 0 && !trim
src = retrieve_code_info(mi, get_inference_world(interp))
else
src = typeinf_code(interp, callee.def, true)
# TODO: typeinf_code could return something with different edges/ages/owner/abi (needing an update to callee), which we don't handle here
src = typeinf_code(interp, mi, true)
end
if src isa CodeInfo
collectinvokes!(tocompile, src)
# It is somewhat ambiguous if typeinf_ext might have callee in the caches,
# but for the purpose of native compile, we always want them put there.
if iszero(ccall(:jl_mi_cache_has_ci, Cint, (Any, Any), mi, callee))
code_cache(interp)[mi] = callee
end
push!(codeinfos, callee)
push!(codeinfos, src)
elseif trim
println("warning: failed to get code for ", mi)
end
end
end
Expand Down
6 changes: 3 additions & 3 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
}
}
else if (params.params->trim) {
jl_safe_printf("warning: no code provided for function");
jl_safe_printf("warning: no code provided for function ");
jl_(codeinst->def);
if (params.params->trim)
abort();
Expand All @@ -441,7 +441,7 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
Function *pinvoke = nullptr;
if (preal_decl.empty()) {
if (invokeName.empty() && params.params->trim) {
jl_safe_printf("warning: bailed out to invoke when compiling:");
jl_safe_printf("warning: bailed out to invoke when compiling: ");
jl_(codeinst->def);
abort();
}
Expand Down Expand Up @@ -658,7 +658,7 @@ void *jl_emit_native_impl(jl_array_t *codeinfos, LLVMOrcThreadSafeModuleRef llvm
else if (params.params->trim) {
// if we're building a small image, we need to compile everything
// to ensure that we have all the information we need.
jl_safe_printf("codegen failed to compile code root");
jl_safe_printf("codegen failed to compile code root ");
jl_(mi);
abort();
}
Expand Down
27 changes: 27 additions & 0 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,33 @@ JL_DLLEXPORT int jl_mi_cache_has_ci(jl_method_instance_t *mi,
return 0;
}

// look for something with an egal ABI and properties that is already in the JIT (compiled=true) or simply in the cache (compiled=false)
JL_DLLEXPORT jl_code_instance_t *jl_get_ci_equiv(jl_code_instance_t *ci JL_PROPAGATES_ROOT, int compiled) JL_NOTSAFEPOINT
{
jl_value_t *def = ci->def;
jl_method_instance_t *mi = jl_get_ci_mi(ci);
jl_value_t *owner = ci->owner;
jl_value_t *rettype = ci->rettype;
size_t min_world = jl_atomic_load_relaxed(&ci->min_world);
size_t max_world = jl_atomic_load_relaxed(&ci->max_world);
jl_code_instance_t *codeinst = jl_atomic_load_relaxed(&mi->cache);
while (codeinst) {
if (codeinst != ci &&
jl_atomic_load_relaxed(&codeinst->inferred) != NULL &&
(!compiled || jl_atomic_load_relaxed(&codeinst->invoke) != NULL) &&
jl_atomic_load_relaxed(&codeinst->min_world) <= min_world &&
jl_atomic_load_relaxed(&codeinst->max_world) >= max_world &&
jl_egal(codeinst->def, def) &&
jl_egal(codeinst->owner, owner) &&
jl_egal(codeinst->rettype, rettype)) {
return codeinst;
}
codeinst = jl_atomic_load_relaxed(&codeinst->next);
}
return (jl_code_instance_t*)jl_nothing;
}


JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst(
jl_method_instance_t *mi, jl_value_t *owner,
jl_value_t *rettype, jl_value_t *exctype,
Expand Down
30 changes: 2 additions & 28 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,32 +301,6 @@ static void finish_params(Module *M, jl_codegen_params_t &params) JL_NOTSAFEPOIN
}
}

// look for something with an egal ABI that is already in the JIT
static jl_code_instance_t *jl_method_compiled_egal(jl_code_instance_t *ci JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT
{
jl_value_t *def = ci->def;
jl_method_instance_t *mi = jl_get_ci_mi(ci);
jl_value_t *owner = ci->owner;
jl_value_t *rettype = ci->rettype;
size_t min_world = jl_atomic_load_relaxed(&ci->min_world);
size_t max_world = jl_atomic_load_relaxed(&ci->max_world);
jl_code_instance_t *codeinst = jl_atomic_load_relaxed(&mi->cache);
while (codeinst) {
if (codeinst != ci &&
jl_atomic_load_relaxed(&codeinst->inferred) != NULL &&
jl_atomic_load_relaxed(&codeinst->invoke) != NULL &&
jl_atomic_load_relaxed(&codeinst->min_world) <= min_world &&
jl_atomic_load_relaxed(&codeinst->max_world) >= max_world &&
jl_egal(codeinst->def, def) &&
jl_egal(codeinst->owner, owner) &&
jl_egal(codeinst->rettype, rettype)) {
return codeinst;
}
codeinst = jl_atomic_load_relaxed(&codeinst->next);
}
return codeinst;
}

static int jl_analyze_workqueue(jl_code_instance_t *callee, jl_codegen_params_t &params, bool forceall=false) JL_NOTSAFEPOINT_LEAVE JL_NOTSAFEPOINT_ENTER
{
jl_task_t *ct = jl_current_task;
Expand Down Expand Up @@ -377,8 +351,8 @@ static int jl_analyze_workqueue(jl_code_instance_t *callee, jl_codegen_params_t
}
if (preal_decl.empty()) {
// there may be an equivalent method already compiled (or at least registered with the JIT to compile), in which case we should be using that instead
jl_code_instance_t *compiled_ci = jl_method_compiled_egal(codeinst);
if (compiled_ci) {
jl_code_instance_t *compiled_ci = jl_get_ci_equiv(codeinst, 1);
if ((jl_value_t*)compiled_ci != jl_nothing) {
codeinst = compiled_ci;
uint8_t specsigflags;
void *fptr;
Expand Down
2 changes: 1 addition & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst(
int32_t const_flags, size_t min_world, size_t max_world,
uint32_t effects, jl_value_t *analysis_results,
uint8_t relocatability, jl_debuginfo_t *di, jl_svec_t *edges /* , int absolute_max*/);
JL_DLLEXPORT jl_code_instance_t *jl_get_ci_equiv(jl_code_instance_t *ci JL_PROPAGATES_ROOT, int compiled) JL_NOTSAFEPOINT;

STATIC_INLINE jl_method_instance_t *jl_get_ci_mi(jl_code_instance_t *ci JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT
{
Expand Down Expand Up @@ -1221,7 +1222,6 @@ JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMEN
JL_DLLEXPORT int jl_mi_try_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,
jl_code_instance_t *expected_ci,
jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED);
JL_DLLEXPORT int jl_mi_cache_has_ci(jl_method_instance_t *mi, jl_code_instance_t *ci) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_code_instance_t *jl_cached_uninferred(jl_code_instance_t *codeinst, size_t world);
JL_DLLEXPORT jl_code_instance_t *jl_cache_uninferred(jl_method_instance_t *mi, jl_code_instance_t *checked, size_t world, jl_code_instance_t *newci JL_MAYBE_UNROOTED);
JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst_for_uninferred(jl_method_instance_t *mi, jl_code_info_t *src);
Expand Down

0 comments on commit a23a6de

Please sign in to comment.