Skip to content

Commit 3d34693

Browse files
committed
fix: batching
1 parent 74b0c7f commit 3d34693

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

Diff for: src/TracedRArray.jl

+1-16
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@ end
1919

2020
TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x
2121

22-
function Base.setproperty!(x::TracedRArray, f::Symbol, v)
23-
if f === :mlir_data && !isnothing(v)
24-
@assert size(MLIR.IR.type(v)) == size(x)
25-
end
26-
return setfield!(x, f, v)
27-
end
28-
2922
mutable struct TracedRNumber{T} <: RNumber{T}
3023
paths::Tuple
3124
mlir_data::Union{Nothing,MLIR.IR.Value}
@@ -40,13 +33,6 @@ mutable struct TracedRNumber{T} <: RNumber{T}
4033
end
4134
end
4235

43-
function Base.setproperty!(x::TracedRNumber, f::Symbol, v)
44-
if f === :mlir_data && !isnothing(v)
45-
@assert size(MLIR.IR.type(v)) == ()
46-
end
47-
return setfield!(x, f, v)
48-
end
49-
5036
Base.eltype(::Type{TracedRNumber{T}}) where {T} = T
5137

5238
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
@@ -318,7 +304,7 @@ for (jlop, hloop) in (
318304
@eval function $(jlop)(
319305
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
320306
) where {T}
321-
return TracedRArray{T}(
307+
return TracedRNumber{T}(
322308
(),
323309
MLIR.IR.result(
324310
MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
@@ -430,7 +416,6 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
430416
residx = 1
431417

432418
for a in linear_results
433-
@show a
434419
if has_residx(a)
435420
path = get_residx(a)
436421
set!(result, path[2:end], MLIR.IR.result(res, residx))

Diff for: src/Tracing.jl

+12-4
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ function make_tracer(
332332
end
333333
res = if toscalar
334334
TracedRNumber{T}((path,), nothing)
335-
elseif !isnothing(tobatch)
336-
TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch)
335+
elseif tobatch !== nothing
336+
error("This should not happen...")
337337
else
338338
TracedRArray{T,N}((path,), prev.mlir_data, size(prev))
339339
end
@@ -358,7 +358,9 @@ function make_tracer(
358358
@nospecialize(prev::TracedRNumber{T}),
359359
@nospecialize(path),
360360
mode;
361-
kwargs...
361+
tobatch=nothing,
362+
toscalar=false,
363+
kwargs...,
362364
) where {T}
363365
if mode == ConcreteToTraced
364366
throw("Cannot trace existing trace type")
@@ -374,7 +376,13 @@ function make_tracer(
374376
if haskey(seen, prev)
375377
return seen[prev]
376378
end
377-
res = TracedRNumber{T}((path,), prev.mlir_data)
379+
res = if toscalar
380+
TracedRNumber{T}((path,), nothing)
381+
elseif tobatch !== nothing
382+
TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch)
383+
else
384+
TracedRNumber{T}((path,), prev.mlir_data)
385+
end
378386
seen[prev] = res
379387
return res
380388
end

0 commit comments

Comments
 (0)