Skip to content

Commit 023f97a

Browse files
committed
feat: partial progress on getting scalars to work
1 parent 981c93a commit 023f97a

File tree

4 files changed

+111
-108
lines changed

4 files changed

+111
-108
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,24 @@
11
module ReactantNNlibExt
22

33
using NNlib
4-
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR
4+
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR,
5+
TracedRScalar
56

67
for (jlop, hloop) in (
78
(:(NNlib.tanh_fast), :tanh),
89
(:(NNlib.sigmoid_fast), :logistic),
910
(:(NNlib.sigmoid), :logistic),
1011
)
11-
@eval function $(jlop)(x::TracedRArray{T,0}) where {T}
12-
return TracedRArray{T,0}(
12+
@eval function $(jlop)(x::TracedRScalar{T}) where {T}
13+
return TracedRScalar{T}(
1314
(),
1415
Reactant.MLIR.IR.result(
1516
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
1617
),
17-
(),
1818
)
1919
end
2020
end
2121

22-
# Don't confuse our poor scalar arrays, we no like numbers we like 0D arrays
23-
for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, ))
24-
@eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
25-
return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
26-
end
27-
end
28-
2922
# TODO handle non finite cases
3023
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
3124
max_ = NNlib.fast_maximum(x; dims)

src/TracedRArray.jl

Lines changed: 63 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
1717
end
1818
end
1919

20+
function Base.setproperty!(x::TracedRArray, f::Symbol, v)
21+
if f === :mlir_data && !isnothing(v)
22+
@assert size(MLIR.IR.type(v)) == size(x)
23+
end
24+
return setfield!(x, f, v)
25+
end
26+
2027
mutable struct TracedRScalar{T} <: RScalar{T}
2128
paths::Tuple
2229
mlir_data::Union{Nothing,MLIR.IR.Value}
@@ -31,6 +38,15 @@ mutable struct TracedRScalar{T} <: RScalar{T}
3138
end
3239
end
3340

41+
function Base.setproperty!(x::TracedRScalar, f::Symbol, v)
42+
if f === :mlir_data && !isnothing(v)
43+
@assert size(MLIR.IR.type(v)) == ()
44+
end
45+
return setfield!(x, f, v)
46+
end
47+
48+
Base.eltype(::Type{TracedRScalar{T}}) where {T} = T
49+
3450
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
3551
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
3652
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
@@ -57,7 +73,7 @@ Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T))
5773
Base.one(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, one(T))
5874

5975
function Base.convert(::Type{<:TracedRScalar{T}}, x::Number) where {T}
60-
return promote_to(TracedRArray{T,0}, T(x))
76+
return promote_to(TracedRScalar{T}, T(x))
6177
end
6278

6379
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
@@ -119,7 +135,7 @@ function Base.setindex!(
119135
a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
120136
) where {T,N}
121137
indices = [
122-
(promote_to(TracedRArray{Int,0}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
138+
(promote_to(TracedRScalar{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
123139
i in indices
124140
]
125141
v = promote_to(TracedRArray{T,N}, v)
@@ -220,6 +236,14 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
220236
return TracedRArray{Base.promote_type(T, S),N}
221237
end
222238

239+
function Base.promote_rule(::Type{T}, ::Type{TracedRScalar{S}}) where {T,S}
240+
return TracedRScalar{Base.promote_type(T, S)}
241+
end
242+
243+
function Base.convert(::Type{TracedRScalar{T}}, x::Number) where {T}
244+
return promote_to(TracedRScalar{T}, x)
245+
end
246+
223247
function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
224248
if isa(rhs, TracedRArray)
225249
rhs isa TracedRArray{T,N} && return rhs
@@ -277,12 +301,8 @@ function promote_to(::Type{TracedRScalar{T}}, rhs) where {T}
277301
)
278302
end
279303

280-
function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
281-
return promote_to(TracedRArray{T,N}, rhs)
282-
end
283-
function promote_to(::TracedRScalar{T}, rhs) where {T}
284-
return promote_to(TracedRScalar{T}, rhs)
285-
end
304+
promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs)
305+
promote_to(::TracedRScalar{T}, rhs) where {T} = promote_to(TracedRScalar{T}, rhs)
286306

287307
for (jlop, hloop) in (
288308
(:(Base.min), :minimum),
@@ -293,66 +313,35 @@ for (jlop, hloop) in (
293313
(:(Base.:/), :divide),
294314
(:(Base.:^), :power),
295315
)
296-
@eval begin
297-
function $(jlop)(
298-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
299-
) where {T}
300-
return TracedRArray{T,0}(
301-
(),
302-
MLIR.IR.result(
303-
MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
304-
),
305-
(),
306-
)
307-
end
308-
309-
function $(jlop)(
310-
@nospecialize(lhs::TracedRArray{T1,0}), @nospecialize(rhs::TracedRArray{T2,0})
311-
) where {T1,T2}
312-
commonTy = TracedRArray{Base.promote_type(T1, T2),0}
313-
lhs = promote_to(commonTy, lhs)
314-
rhs = promote_to(commonTy, rhs)
315-
return $(jlop)(lhs, rhs)
316-
end
317-
end
318-
319-
for otherType in (Number, Any)
320-
@eval begin
321-
function $(jlop)(
322-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::$(otherType))
323-
) where {T}
324-
rhs = promote_to(lhs, rhs)
325-
return $(jlop)(lhs, rhs)
326-
end
327-
328-
function $(jlop)(
329-
@nospecialize(lhs::$(otherType)), @nospecialize(rhs::TracedRArray{T,0})
330-
) where {T}
331-
lhs = promote_to(rhs, lhs)
332-
return $(jlop)(lhs, rhs)
333-
end
334-
end
316+
@eval function $(jlop)(
317+
@nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T})
318+
) where {T}
319+
return TracedRArray{T}(
320+
(),
321+
MLIR.IR.result(
322+
MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
323+
),
324+
)
335325
end
336326
end
337327

338328
function Base.ifelse(
339-
@nospecialize(pred::TracedRArray{Bool,0}),
340-
@nospecialize(x::TracedRArray{T1,0}),
341-
@nospecialize(y::TracedRArray{T2,0})
329+
@nospecialize(pred::TracedRScalar{Bool}),
330+
@nospecialize(x::TracedRScalar{T1}),
331+
@nospecialize(y::TracedRScalar{T2})
342332
) where {T1,T2}
343-
return TracedRArray{promote_type(T1, T2),0}(
333+
return TracedRScalar{promote_type(T1, T2)}(
344334
(),
345335
MLIR.IR.result(
346336
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
347337
),
348-
size(pred),
349338
)
350339
end
351340

352-
Base.abs2(x::Reactant.TracedRArray{T,0}) where {T} = x * conj(x)
341+
Base.abs2(x::Reactant.TracedRScalar{T}) where {T} = x * conj(x)
353342

354343
function Base.literal_pow(
355-
::Base.RefValue{typeof(^)}, x::TracedRArray{T,0}, ::Base.RefValue{Val{P}}
344+
::Base.RefValue{typeof(^)}, x::TracedRScalar{T}, ::Base.RefValue{Val{P}}
356345
) where {T,P}
357346
return Base.literal_pow(^, x, Val(P))
358347
end
@@ -369,14 +358,10 @@ for (jlop, hloop) in (
369358
(:(Base.log), :log),
370359
(:(Base.sqrt), :sqrt),
371360
)
372-
@eval begin
373-
function $jlop(@nospecialize(lhs::TracedRArray{T,0})) where {T}
374-
return TracedRArray{T,0}(
375-
(),
376-
MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1),
377-
size(lhs),
378-
)
379-
end
361+
@eval function $(jlop)(@nospecialize(lhs::TracedRScalar{T})) where {T}
362+
return TracedRScalar{T}(
363+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1)
364+
)
380365
end
381366
end
382367

@@ -443,6 +428,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
443428
residx = 1
444429

445430
for a in linear_results
431+
@show a
446432
if has_residx(a)
447433
path = get_residx(a)
448434
set!(result, path[2:end], MLIR.IR.result(res, residx))
@@ -478,37 +464,22 @@ for (jlop, hloop, hlocomp, merge) in (
478464
(:(Base.:(<=)), :compare, "LE", nothing),
479465
(:(Base.:(<)), :compare, "LT", nothing),
480466
)
481-
@eval begin
482-
function $(jlop)(
483-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
484-
) where {T}
485-
return TracedRArray{Bool,0}(
486-
(),
487-
MLIR.IR.result(
488-
MLIR.Dialects.stablehlo.$hloop(
489-
lhs.mlir_data,
490-
rhs.mlir_data;
491-
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
492-
MLIR.IR.context(), $hlocomp
493-
),
467+
@eval function $(jlop)(
468+
@nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T})
469+
) where {T}
470+
return TracedRScalar{Bool}(
471+
(),
472+
MLIR.IR.result(
473+
MLIR.Dialects.stablehlo.$(hloop)(
474+
lhs.mlir_data,
475+
rhs.mlir_data;
476+
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
477+
MLIR.IR.context(), $hlocomp
494478
),
495-
1,
496479
),
497-
size(lhs),
498-
)
499-
end
500-
501-
function $(jlop)(
502-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs)
503-
) where {T}
504-
return $(jlop)(lhs, promote_to(lhs, rhs))
505-
end
506-
507-
function $(jlop)(
508-
@nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,0})
509-
) where {T}
510-
return $(jlop)(promote_to(rhs, lhs), rhs)
511-
end
480+
1,
481+
),
482+
)
512483
end
513484

514485
if merge !== nothing
@@ -598,7 +569,7 @@ function Base.mapreduce(
598569
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys])
599570

600571
args = (
601-
TracedRArray{T,0}((), MLIR.IR.argument(fnbody, i), ()) for
572+
TracedRScalar{T}((), MLIR.IR.argument(fnbody, i), ()) for
602573
(i, ty) in enumerate(in_tys)
603574
)
604575

src/Tracing.jl

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ for T in (
1616
Integer,
1717
AbstractString,
1818
RArray,
19+
RScalar,
1920
)
2021
@eval function traced_type(::Type{T}, seen, mode) where {T<:$T}
2122
return T
@@ -330,7 +331,7 @@ function make_tracer(
330331
return seen[prev]
331332
end
332333
res = if toscalar
333-
TracedRArray{T,0}((path,), nothing, ())
334+
TracedRScalar{T}((path,), nothing)
334335
elseif !isnothing(tobatch)
335336
TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch)
336337
else
@@ -352,6 +353,44 @@ function make_tracer(
352353
throw("Cannot Unknown trace mode $mode")
353354
end
354355

356+
function make_tracer(
357+
seen,
358+
@nospecialize(prev::TracedRScalar{T}),
359+
@nospecialize(path),
360+
mode;
361+
kwargs...
362+
) where {T}
363+
if mode == ConcreteToTraced
364+
throw("Cannot trace existing trace type")
365+
end
366+
if mode == TracedTrack
367+
prev.paths = (prev.paths..., path)
368+
if !haskey(seen, prev)
369+
return seen[prev] = prev
370+
end
371+
return prev
372+
end
373+
if mode == TracedSetPath
374+
if haskey(seen, prev)
375+
return seen[prev]
376+
end
377+
res = TracedRScalar{T}((path,), prev.mlir_data)
378+
seen[prev] = res
379+
return res
380+
end
381+
382+
if mode == TracedToConcrete
383+
if haskey(seen, prev)
384+
return seen[prev]::ConcreteRArray{T,0}
385+
end
386+
res = ConcreteRArray{T,0}(XLA.AsyncEmptyBuffer, size(prev))
387+
seen[prev] = res
388+
return res
389+
end
390+
391+
throw("Cannot Unknown trace mode $mode")
392+
end
393+
355394
function make_tracer(
356395
seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
357396
) where {RT<:AbstractFloat}

src/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
4444
)
4545
end
4646

47-
linear_args = TracedRArray[]
47+
linear_args = Union{TracedRArray,TracedRScalar}[]
4848
for (k, v) in seen_args
49-
if !(v isa TracedRArray)
49+
if !(v isa TracedRArray) && !(v isa TracedRScalar)
5050
continue
5151
end
5252
push!(linear_args, v)
@@ -127,10 +127,10 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
127127
)
128128
end
129129

130-
linear_results = TracedRArray[]
130+
linear_results = Union{TracedRArray,TracedRScalar}[]
131131

132132
for (k, v) in seen_results
133-
if !(v isa TracedRArray)
133+
if !(v isa TracedRArray) && !(v isa TracedRScalar)
134134
continue
135135
end
136136

0 commit comments

Comments
 (0)