Skip to content

Commit 5f4f97e

Browse files
committed
feat: add leaf macro to avoid tracing into types
1 parent 9177412 commit 5f4f97e

File tree

4 files changed

+80
-33
lines changed

4 files changed

+80
-33
lines changed

docs/src/api/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@ Reactant.Profiler.@annotate
4646
Reactant.devices
4747
Reactant.addressable_devices
4848
```
49+
50+
## Tracing
51+
52+
```@docs
53+
Reactant.@leaf
54+
```

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue
44

55
using LinearAlgebra: LinearAlgebra
66
using Random: Random, AbstractRNG
7-
using Functors: @leaf
7+
using Functors: Functors
88

99
using Adapt: Adapt, WrappedArray
1010
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`

src/Tracing.jl

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,64 @@
1+
function traced_type_inner end
2+
function make_tracer end
3+
4+
"""
5+
@leaf type [make_tracer = true]
6+
7+
This marks a type as a leaf type for the purposes of tracing in reactant. This means that
8+
we won't recurse into the type and it will be left untouched.
9+
"""
10+
macro leaf(args...)
11+
@assert length(args) 1
12+
orig_type, args = args[1], args[2:end]
13+
14+
options = Dict{Symbol,Any}()
15+
while length(args) 1
16+
if !Meta.isexpr(args[1], :(=))
17+
error("Invalid argument $(args[1])")
18+
end
19+
options[args[1].args[1]] = args[1].args[2]
20+
args = args[2:end]
21+
end
22+
23+
subtype = Meta.isexpr(orig_type, :(<:))
24+
type = subtype ? orig_type.args[1] : orig_type
25+
26+
traced_type_inner_expr = quote
27+
Base.@nospecializeinfer function Reactant.traced_type_inner(
28+
@nospecialize(T::Type{$(orig_type)}),
29+
seen,
30+
@nospecialize(mode::$(TraceMode)),
31+
@nospecialize(track_numbers::Type),
32+
@nospecialize(sharding),
33+
)
34+
return T
35+
end
36+
end
37+
38+
make_tracer_expr = if get(options, :make_tracer, true)
39+
quote
40+
function Reactant.make_tracer(
41+
seen,
42+
@nospecialize(prev::$(type)),
43+
@nospecialize(path),
44+
mode::$(TraceMode);
45+
kwargs...,
46+
)
47+
return prev
48+
end
49+
end
50+
else
51+
:()
52+
end
53+
54+
return esc(
55+
quote
56+
$traced_type_inner_expr
57+
$make_tracer_expr
58+
end,
59+
)
60+
end
61+
162
@enum TraceMode begin
263
ConcreteToTraced = 1
364
TracedTrack = 2
@@ -14,38 +75,27 @@ end
1475

1576
function traced_type_inner end
1677

17-
Base.@nospecializeinfer function traced_type_inner(
18-
@nospecialize(T::Type{Union{}}),
19-
seen,
20-
mode::TraceMode,
21-
@nospecialize(track_numbers::Type),
22-
@nospecialize(sharding)
23-
)
24-
return T
78+
for T in (Symbol, Union{})
79+
@eval begin
80+
@leaf $T make_tracer = false
81+
end
2582
end
2683

2784
for T in (
2885
DataType,
2986
Module,
3087
Nothing,
31-
Symbol,
3288
AbstractChar,
3389
AbstractString,
3490
AbstractFloat,
3591
Integer,
3692
RNumber,
3793
Val,
3894
VersionNumber,
95+
Base.ExceptionStack,
96+
Core.MethodInstance,
3997
)
40-
@eval Base.@nospecializeinfer function traced_type_inner(
41-
@nospecialize(T::Type{<:$T}),
42-
seen,
43-
mode::TraceMode,
44-
@nospecialize(track_numbers::Type),
45-
@nospecialize(sharding)
46-
)
47-
return T
48-
end
98+
@eval @leaf <:$T
4999
end
50100

51101
Base.@nospecializeinfer function traced_type_inner(
@@ -754,15 +804,6 @@ function Base.showerror(io::IO, err::NoFieldMatchError)
754804
end
755805
end
756806

757-
function make_tracer(
758-
seen,
759-
@nospecialize(prev::Union{Base.ExceptionStack,Core.MethodInstance}),
760-
@nospecialize(path),
761-
mode;
762-
kwargs...,
763-
)
764-
return prev
765-
end
766807
append_path(@nospecialize(path), i) = (path..., i)
767808

768809
function make_tracer(

src/Types.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ abstract type AbstractConcreteArray{T,N} <: RArray{T,N} end
99
# Traced Types
1010

1111
## MissingTracedValue -- defined in ReactantCore
12-
@leaf MissingTracedValue
12+
Functors.@leaf MissingTracedValue
1313

1414
## TracedRNumber
1515
mutable struct TracedRNumber{T} <: RNumber{T}
@@ -26,7 +26,7 @@ mutable struct TracedRNumber{T} <: RNumber{T}
2626
end
2727
end
2828

29-
@leaf TracedRNumber
29+
Functors.@leaf TracedRNumber
3030

3131
## TracedRArray
3232
mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
@@ -45,7 +45,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
4545
end
4646
end
4747

48-
@leaf TracedRArray
48+
Functors.@leaf TracedRArray
4949
Adapt.parent_type(::Type{TracedRArray{T,N}}) where {T,N} = TracedRArray{T,N}
5050

5151
const WrappedTracedRArray{T,N} = WrappedArray{
@@ -79,7 +79,7 @@ function ConcretePJRTNumber{T}(data::Tuple{XLA.PJRT.AsyncBuffer}) where {T}
7979
return ConcretePJRTNumber{T,1,Sharding.NoShardInfo}(data, Sharding.NoShardInfo())
8080
end
8181

82-
@leaf ConcretePJRTNumber
82+
Functors.@leaf ConcretePJRTNumber
8383

8484
function ConcretePJRTNumber{T}(data::T2; kwargs...) where {T<:Number,T2<:Number}
8585
carray = ConcretePJRTArray(fill(convert(T, data)); kwargs...)
@@ -105,7 +105,7 @@ mutable struct ConcretePJRTArray{T,N,D,S<:Sharding.ShardInfo} <: AbstractConcret
105105
sharding::S
106106
end
107107

108-
@leaf ConcretePJRTArray
108+
Functors.@leaf ConcretePJRTArray
109109
Adapt.parent_type(::Type{<:ConcretePJRTArray{T,N}}) where {T,N} = ConcretePJRTArray{T,N}
110110
function Adapt.parent_type(::Type{ConcretePJRTArray{T,N,D,S}}) where {T,N,D,S}
111111
return ConcretePJRTArray{T,N,D,S}

0 commit comments

Comments
 (0)