Skip to content

Commit 209001e

Browse files
committed
feat: add leaf macro to avoid tracing into types
1 parent 9177412 commit 209001e

File tree

4 files changed

+75
-33
lines changed

4 files changed

+75
-33
lines changed

docs/src/api/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@ Reactant.Profiler.@annotate
4646
Reactant.devices
4747
Reactant.addressable_devices
4848
```
49+
50+
## Tracing
51+
52+
```@docs
53+
Reactant.@leaf
54+
```

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue
44

55
using LinearAlgebra: LinearAlgebra
66
using Random: Random, AbstractRNG
7-
using Functors: @leaf
7+
using Functors: Functors
88

99
using Adapt: Adapt, WrappedArray
1010
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

src/Tracing.jl

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,61 @@
1+
"""
2+
@leaf type [make_tracer = true]
3+
4+
This marks a type as a leaf type for the purposes of tracing in reactant. This means that
5+
we won't recurse into the type and it will be left untouched.
6+
"""
7+
macro leaf(args...)
8+
@assert length(args) 1
9+
orig_type, args = args[1], args[2:end]
10+
11+
options = Dict{Symbol,Any}()
12+
while length(args) 1
13+
if !Meta.isexpr(args[1], :(=))
14+
error("Invalid argument $(args[1])")
15+
end
16+
options[args[1].args[1]] = args[1].args[2]
17+
args = args[2:end]
18+
end
19+
20+
subtype = Meta.isexpr(orig_type, :(<:))
21+
type = subtype ? orig_type.args[1] : orig_type
22+
23+
traced_type_inner_expr = quote
24+
Base.@nospecializeinfer function Reactant.traced_type_inner(
25+
@nospecialize(T::Type{$(orig_type)}),
26+
seen,
27+
@nospecialize(mode::$(TraceMode)),
28+
@nospecialize(track_numbers::Type),
29+
@nospecialize(sharding),
30+
)
31+
return T
32+
end
33+
end
34+
35+
make_tracer_expr = if get(options, :make_tracer, true)
36+
quote
37+
function Reactant.make_tracer(
38+
seen,
39+
@nospecialize(prev::$(type)),
40+
@nospecialize(path),
41+
mode::$(TraceMode);
42+
kwargs...,
43+
)
44+
return prev
45+
end
46+
end
47+
else
48+
:()
49+
end
50+
51+
return esc(
52+
quote
53+
$traced_type_inner_expr
54+
$make_tracer_expr
55+
end,
56+
)
57+
end
58+
159
@enum TraceMode begin
260
ConcreteToTraced = 1
361
TracedTrack = 2
@@ -14,38 +72,25 @@ end
1472

1573
function traced_type_inner end
1674

17-
Base.@nospecializeinfer function traced_type_inner(
18-
@nospecialize(T::Type{Union{}}),
19-
seen,
20-
mode::TraceMode,
21-
@nospecialize(track_numbers::Type),
22-
@nospecialize(sharding)
23-
)
24-
return T
75+
for T in (Symbol, Union{})
76+
@eval @leaf $T make_tracer = false
2577
end
2678

2779
for T in (
2880
DataType,
2981
Module,
3082
Nothing,
31-
Symbol,
3283
AbstractChar,
3384
AbstractString,
3485
AbstractFloat,
3586
Integer,
3687
RNumber,
3788
Val,
3889
VersionNumber,
90+
Base.ExceptionStack,
91+
Core.MethodInstance,
3992
)
40-
@eval Base.@nospecializeinfer function traced_type_inner(
41-
@nospecialize(T::Type{<:$T}),
42-
seen,
43-
mode::TraceMode,
44-
@nospecialize(track_numbers::Type),
45-
@nospecialize(sharding)
46-
)
47-
return T
48-
end
93+
@eval @leaf <:$T
4994
end
5095

5196
Base.@nospecializeinfer function traced_type_inner(
@@ -754,15 +799,6 @@ function Base.showerror(io::IO, err::NoFieldMatchError)
754799
end
755800
end
756801

757-
function make_tracer(
758-
seen,
759-
@nospecialize(prev::Union{Base.ExceptionStack,Core.MethodInstance}),
760-
@nospecialize(path),
761-
mode;
762-
kwargs...,
763-
)
764-
return prev
765-
end
766802
append_path(@nospecialize(path), i) = (path..., i)
767803

768804
function make_tracer(

src/Types.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ abstract type AbstractConcreteArray{T,N} <: RArray{T,N} end
99
# Traced Types
1010

1111
## MissingTracedValue -- defined in ReactantCore
12-
@leaf MissingTracedValue
12+
Functors.@leaf MissingTracedValue
1313

1414
## TracedRNumber
1515
mutable struct TracedRNumber{T} <: RNumber{T}
@@ -26,7 +26,7 @@ mutable struct TracedRNumber{T} <: RNumber{T}
2626
end
2727
end
2828

29-
@leaf TracedRNumber
29+
Functors.@leaf TracedRNumber
3030

3131
## TracedRArray
3232
mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
@@ -45,7 +45,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
4545
end
4646
end
4747

48-
@leaf TracedRArray
48+
Functors.@leaf TracedRArray
4949
Adapt.parent_type(::Type{TracedRArray{T,N}}) where {T,N} = TracedRArray{T,N}
5050

5151
const WrappedTracedRArray{T,N} = WrappedArray{
@@ -79,7 +79,7 @@ function ConcretePJRTNumber{T}(data::Tuple{XLA.PJRT.AsyncBuffer}) where {T}
7979
return ConcretePJRTNumber{T,1,Sharding.NoShardInfo}(data, Sharding.NoShardInfo())
8080
end
8181

82-
@leaf ConcretePJRTNumber
82+
Functors.@leaf ConcretePJRTNumber
8383

8484
function ConcretePJRTNumber{T}(data::T2; kwargs...) where {T<:Number,T2<:Number}
8585
carray = ConcretePJRTArray(fill(convert(T, data)); kwargs...)
@@ -105,7 +105,7 @@ mutable struct ConcretePJRTArray{T,N,D,S<:Sharding.ShardInfo} <: AbstractConcret
105105
sharding::S
106106
end
107107

108-
@leaf ConcretePJRTArray
108+
Functors.@leaf ConcretePJRTArray
109109
Adapt.parent_type(::Type{<:ConcretePJRTArray{T,N}}) where {T,N} = ConcretePJRTArray{T,N}
110110
function Adapt.parent_type(::Type{ConcretePJRTArray{T,N,D,S}}) where {T,N,D,S}
111111
return ConcretePJRTArray{T,N,D,S}

0 commit comments

Comments
 (0)