Skip to content

Commit

Permalink
within_tracing, approach taken from Enzyme.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Dec 17, 2024
1 parent 55d1527 commit 566e809
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 14 deletions.
20 changes: 12 additions & 8 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ module ReactantCore
using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

using ScopedValues
const enable_tracing = ScopedValue{Bool}(false)

export @trace, MissingTracedValue
export @trace, within_tracing, MissingTracedValue

# Traits
is_traced(x) = false
Expand All @@ -22,6 +19,13 @@ const SPECIAL_SYMBOLS = [
:(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core
]

"""
within_tracing()
Returns true if within tracing, otherwise false.
"""
@inline within_tracing() = false # behavior is overwritten in Interpreter.jl

# Code generation
"""
@trace <expr>
Expand Down Expand Up @@ -186,7 +190,7 @@ function trace_for(mod, expr)
end

return quote
if $(enable_tracing)[] && $(any)($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
if $(within_tracing)() && $(any)($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
$(reactant_code_block)
else
$(expr)
Expand All @@ -200,7 +204,7 @@ function trace_if_with_returns(mod, expr)
mod, expr.args[2]; store_last_line=expr.args[1], depth=1
)
return quote
if $(enable_tracing)[] && $(any)($(is_traced), ($(all_check_vars...),))
if $(within_tracing)() && $(any)($(is_traced), ($(all_check_vars...),))
$(new_expr)
else
$(expr)
Expand Down Expand Up @@ -346,7 +350,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
)

return quote
if $(enable_tracing)[] && $(any)($(is_traced), ($(all_check_vars...),))
if $(within_tracing)() && $(any)($(is_traced), ($(all_check_vars...),))
$(reactant_code_block)
else
$(original_expr)
Expand All @@ -358,7 +362,7 @@ function trace_call(mod, expr)
f = expr.args[1]
args = expr.args[2:end]
return quote
if $(enable_tracing)[]
if $(within_tracing)()
$(traced_call)($f, $(args...))
else
$(expr)
Expand Down
5 changes: 1 addition & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import ..Reactant:
TracedType,
Cached
using ScopedValues
import ReactantCore: enable_tracing

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
Expand Down Expand Up @@ -297,9 +296,7 @@ function compile_mlir!(mod, f, args, callcache; optimize::Union{Bool,Symbol}=tru
linear_results = MLIR.IR.mmodule!(mod) do
MLIR.IR.block!(MLIR.IR.body(mod)) do
callcache!(callcache) do
with(enable_tracing=>true) do
return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
end
return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
end
end
end
Expand Down
24 changes: 24 additions & 0 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ function set_reactant_abi(
)
(; fargs, argtypes) = arginfo

if f === ReactantCore.within_tracing
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
end
end
@static if VERSION < v"1.11.0-"
return CallMeta(
Core.Const(true),
Core.Compiler.EFFECTS_TOTAL,
MethodResultPure(),
)
else
return CallMeta(
Core.Const(true),
Union{},
Core.Compiler.EFFECTS_TOTAL,
MethodResultPure(),
)
end
end

# Improve inference by considering call_with_reactant as having the same results as
# the original call
if f === Reactant.call_with_reactant
Expand Down
4 changes: 2 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Reactant

using ReactantCore: ReactantCore, @trace, MissingTracedValue
using ReactantCore: ReactantCore, @trace, within_tracing, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Adapt: Adapt, WrappedArray
Expand Down Expand Up @@ -145,7 +145,7 @@ function Enzyme.make_zero(
end

using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace, within_tracing

const registry = Ref{MLIR.IR.DialectRegistry}()
function __init__()
Expand Down

0 comments on commit 566e809

Please sign in to comment.